SR Inference¶

Use trained weights to super-resolve a new low-res depth raster, conditioned on a high-res DEM, for a raster pair of arbitrary shape.

Also compute performance metrics against a validation depth raster.

In [1]:
# Standard library + scientific stack used throughout inference notebook.
from pathlib import Path
import math
import json

import numpy as np
import pandas as pd
import rasterio


import matplotlib.pyplot as plt
In [2]:
# Reuse shared preprocessing and diagnostics helpers from training code.
from t02.train import (
    extract_sr_dem_prediction_np,
    invert_depth_log1p_np,
    normalize_dem,
    prepare_sr_dem_model_inputs_np,
    resize_bilinear_2d_np,
    scale_depth_log1p_np,
)
import t02.results as results
2026-02-06 14:35:25.309164: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-06 14:35:25.309225: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-06 14:35:25.310300: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered

define raster paths

In [3]:
# Input rasters for DEM-conditioned super-resolution inference.
# Paths are relative to the notebook's working directory.
dem_fp = Path('_inputs/RSSHydro/dudelange/002/DEM.tif')  # DEM passed to the model alongside the lo-res depth
depth_lores_fp = Path('_inputs/RSSHydro/dudelange/032/ResultA.tif')  # lo-res depth raster to be super-resolved
depth_hires_valid_fp = Path('_inputs/RSSHydro/dudelange/002/ResultA.tif')  # hi-res depth used as the validation target

# Fail fast with a readable message if any input is missing.
assert dem_fp.exists(), f"DEM file not found: {dem_fp}"
assert depth_lores_fp.exists(), f"Lo-res depth file not found: {depth_lores_fp}"
assert depth_hires_valid_fp.exists(), f"Hi-res validation depth file not found: {depth_hires_valid_fp}"
In [4]:
# Core run configuration:
# - tile geometry
# - windowing strategy
# - preprocessing constants aligned with training
# MAX_DEPTH, DEM_PCT_CLIP and DEM_REF_STATS are defaults only: the model-loading
# cell overwrites them when a train_config.json sits next to the model file.
SCALE = 4  # amount of super-resolution to perform. Must be 4 for this model.
LR_TILE = 64  # tile edge length on the LR grid (pixels)
HR_TILE = LR_TILE * SCALE  # matching tile edge length on the HR grid (pixels)
WINDOW_METHOD = "feather"  # "hard" (no overlap) or "feather" (overlap + feather blending)
FEATHER_OVERLAP_LR = LR_TILE // 4  # overlap size at LR scale (used when WINDOW_METHOD == "feather")

MAX_DEPTH = 5.0  # training clip used in SR_ResUNet_DEM.ipynb
DEM_PCT_CLIP = 95.0  # passed to normalize_dem as pct_clip
DEM_REF_STATS = None  # optional train DEM normalization stats loaded from train_config.json

DRY_DEPTH_THRESH_M = results.DEFAULT_DRY_DEPTH_THRESH_M  # keep dry/wet threshold consistent with training diagnostics

# Validate the configuration before any expensive work is done.
assert MAX_DEPTH > 0, f"MAX_DEPTH must be > 0; got {MAX_DEPTH}"
assert WINDOW_METHOD in {"hard", "feather"}, f"Unsupported WINDOW_METHOD: {WINDOW_METHOD!r}"
assert 0 <= FEATHER_OVERLAP_LR < LR_TILE, (
    f"FEATHER_OVERLAP_LR must be in [0, {LR_TILE}); got {FEATHER_OVERLAP_LR}"
)
if WINDOW_METHOD == "feather":
    # A zero overlap would leave nothing to blend between neighbouring tiles.
    assert FEATHER_OVERLAP_LR > 0, "FEATHER_OVERLAP_LR must be > 0 when WINDOW_METHOD='feather'"

model weights¶

see t02/SR_ResUNet_DEM.ipynb for training pipeline

In [5]:
# Locate the exported inference model and, when a train_config.json sits beside it,
# adopt the preprocessing constants recorded at training time.
# train.py exports `model_infer.keras` *after* loading best validation-loss weights,
# so this file can be treated as the best checkpoint for inference.
model_fp = Path('t02/train/full/02/model_infer.keras')
assert model_fp.exists(), f"Model file not found: {model_fp}"

train_config_fp = model_fp.parent / 'train_config.json'
if not train_config_fp.exists():
    # No recorded config: the defaults from the configuration cell stay in effect.
    print(f"train_config.json not found at {train_config_fp}; using notebook defaults for preprocessing.")
else:
    train_cfg = json.loads(train_config_fp.read_text())

    # Only keys actually present in the config override the notebook defaults.
    if 'max_depth' in train_cfg:
        MAX_DEPTH = float(train_cfg['max_depth'])
    if 'dem_pct_clip' in train_cfg:
        DEM_PCT_CLIP = float(train_cfg['dem_pct_clip'])
    if 'dem_stats' in train_cfg and train_cfg['dem_stats'] is not None:
        DEM_REF_STATS = {
            stat_name: float(stat_value)
            for stat_name, stat_value in train_cfg['dem_stats'].items()
        }

    # Re-validate, since the values may just have been replaced.
    assert MAX_DEPTH > 0, f"MAX_DEPTH must be > 0; got {MAX_DEPTH}"
    assert 0 < DEM_PCT_CLIP <= 100, f"DEM_PCT_CLIP must be in (0, 100]; got {DEM_PCT_CLIP}"
    if DEM_REF_STATS is not None:
        required = {'p_clip', 'dem_min', 'dem_max'}
        missing = {stat_name for stat_name in required if stat_name not in DEM_REF_STATS}
        if missing:
            raise AssertionError(f"dem_stats is missing keys: {sorted(missing)}")

    print(f"Loaded train config from {train_config_fp}")
    print(f"Using MAX_DEPTH={MAX_DEPTH}, DEM_PCT_CLIP={DEM_PCT_CLIP}")
    if DEM_REF_STATS is not None:
        print(f"Using train DEM stats: {DEM_REF_STATS}")
Loaded train config from t02/train/full/02/train_config.json
Using MAX_DEPTH=5.0, DEM_PCT_CLIP=95.0
Using train DEM stats: {'dem_max': 1036.0579833984375, 'dem_min': 176.46800231933594, 'p_clip': 1036.0579833984375}

Input Checks¶

In [6]:
# Raster I/O and validation helpers.
# These checks catch geospatial mismatches before model inference.
def _assert_square_pixels(src, name):
    res = src.res
    if not np.isclose(abs(res[0]), abs(res[1])):
        raise AssertionError(f"{name} pixels are not square: res={res}")


def _assert_bounds_close(a, b, name_a, name_b, tol=1e-6):
    if not all(np.isclose(av, bv, atol=tol) for av, bv in zip(a, b)):
        raise AssertionError(f"{name_a} bounds {a} != {name_b} bounds {b}")


def _assert_no_nodata(arr, name):
    if np.ma.isMaskedArray(arr) and np.any(arr.mask):
        raise AssertionError(f"{name} contains nodata/masked values")
    if np.isnan(np.asarray(arr)).any():
        raise AssertionError(f"{name} contains NaNs")


def read_and_check_rasters(dem_fp, depth_lores_fp, depth_hires_fp):
    """Open the DEM, lo-res depth and hi-res depth rasters, validate them, and read band 1 of each.

    Checks performed (each failure raises AssertionError):
      * all three rasters share one CRS and the same bounds;
      * the DEM has the same pixel dimensions as the hi-res depth raster;
      * pixels are square in every raster;
      * no masked/nodata cells and no NaNs;
      * depths lie in [0, 15] and DEM values lie in [0, 5000].

    A LR/HR resolution ratio different from the module-level ``SCALE`` only
    prints a warning.

    Returns:
        ``(dem, lr, hr, dem_profile, lr_profile, hr_profile)`` where the first
        three are plain numpy arrays and the last three are the rasterio
        profiles of the corresponding files.
    """
    with rasterio.open(dem_fp) as dem_src, \
        rasterio.open(depth_lores_fp) as lr_src, \
        rasterio.open(depth_hires_fp) as hr_src:
        # CRS and bounds checks: all inputs must be on the same grid/extent.
        if dem_src.crs != hr_src.crs or lr_src.crs != hr_src.crs:
            raise AssertionError("CRS mismatch between rasters")
        _assert_bounds_close(dem_src.bounds, hr_src.bounds, "DEM", "HR")
        _assert_bounds_close(lr_src.bounds, hr_src.bounds, "LR", "HR")

        # Equal bounds do not imply an equal pixel grid. A DEM at a different
        # resolution would later be sliced to the HR shape and silently cover
        # only part of the scene, so require identical dimensions here.
        dem_shape = (dem_src.height, dem_src.width)
        hr_shape = (hr_src.height, hr_src.width)
        if dem_shape != hr_shape:
            raise AssertionError(
                f"DEM shape {dem_shape} != HR shape {hr_shape}; "
                "DEM must be on the same pixel grid as the hi-res depth raster"
            )

        _assert_square_pixels(dem_src, "DEM")
        _assert_square_pixels(lr_src, "LR")
        _assert_square_pixels(hr_src, "HR")

        # Read single-band rasters as masked arrays first to detect nodata explicitly.
        dem = dem_src.read(1, masked=True)
        lr = lr_src.read(1, masked=True)
        hr = hr_src.read(1, masked=True)

        _assert_no_nodata(dem, "DEM")
        _assert_no_nodata(lr, "LR depth")
        _assert_no_nodata(hr, "HR depth")

        # Nothing is masked at this point, so dropping the mask loses no information.
        dem = np.asarray(dem)
        lr = np.asarray(lr)
        hr = np.asarray(hr)

        # Plausibility ranges for the values themselves.
        if np.min(lr) < 0 or np.min(hr) < 0:
            raise AssertionError("Depth rasters contain negative values")
        if np.max(lr) > 15 or np.max(hr) > 15:
            raise AssertionError("Depth rasters exceed 15m max depth")
        if np.min(dem) < 0 or np.max(dem) > 5000:
            raise AssertionError("DEM raster outside expected range [0, 5000]")

        # scale sanity (warn only): model expects a fixed LR->HR ratio.
        res_ratio = lr_src.res[0] / hr_src.res[0]
        if not np.isclose(res_ratio, SCALE):
            print(f"WARNING: LR/HR resolution ratio {res_ratio:.2f} != SCALE={SCALE}. ")
            print("         LR will be resampled to match model scale.")

        # Profiles are read while the datasets are still open.
        return dem, lr, hr, dem_src.profile, lr_src.profile, hr_src.profile
    
# Read + validate
# Produces the raw arrays plus the rasterio profile of each file.
dem_raw, lr_raw, hr_raw, dem_profile, lr_profile, hr_profile = read_and_check_rasters(
    dem_fp, depth_lores_fp, depth_hires_valid_fp
)

print("Raw shapes (HR, LR):", hr_raw.shape, lr_raw.shape)
WARNING: LR/HR resolution ratio 10.00 != SCALE=4. 
         LR will be resampled to match model scale.
Raw shapes (HR, LR): (2030, 2090) (203, 209)

plot raw inputs¶

In [7]:
# Raw-input sanity check, run before any normalization.
# Layout: one row per raster (LR depth, HR depth, DEM); histogram on the left, map on the right.
# NOTE(review): the normalized-input diagnostics cell repeats this same loop; a shared helper would remove the copy.
plot_specs_raw = [
    ("LR depth (raw)", lr_raw, "viridis", True, DRY_DEPTH_THRESH_M),
    ("HR depth (raw)", hr_raw, "viridis", True, DRY_DEPTH_THRESH_M),
    ("DEM (raw)", dem_raw, "terrain", False, None),
]

fig, axes = plt.subplots(nrows=3, ncols=2, figsize=(10, 12))

for (ax_hist, ax_raster), spec in zip(axes, plot_specs_raw):
    title, arr, cmap, use_dry_mask, dry_thresh = spec
    arr = np.asarray(arr)
    vals = arr[np.isfinite(arr)]  # histogram and stats ignore non-finite cells

    # Left panel: value distribution, with the dry/wet threshold marked for depth rasters.
    ax_hist.hist(vals, bins=60, color='steelblue', alpha=0.9)
    if use_dry_mask:
        ax_hist.axvline(dry_thresh, color='red', linestyle='--', linewidth=1.5)
    ax_hist.set_title(f"{title} histogram")
    ax_hist.set_xlabel('Value')
    ax_hist.set_ylabel('Count')
    ax_hist.grid(color='lightgrey', linestyle='-', linewidth=0.7)

    # Summary statistics in the top-right corner of the histogram.
    stats_text = f"shape: {arr.shape}\nmin: {vals.min():.3f}\nmax: {vals.max():.3f}\nmean: {vals.mean():.3f}\nstd: {vals.std():.3f}"
    ax_hist.text(
        0.98,
        0.95,
        stats_text,
        transform=ax_hist.transAxes,
        fontsize=9,
        verticalalignment='top',
        horizontalalignment='right',
    )

    # Right panel: the raster itself; depth cells below the dry threshold are hidden.
    if use_dry_mask:
        raster_arr = np.ma.masked_where(arr < dry_thresh, arr)
    else:
        raster_arr = arr
    im = ax_raster.imshow(raster_arr, cmap=cmap)
    ax_raster.set_title(f"{title} raster")
    ax_raster.set_axis_off()
    fig.colorbar(im, ax=ax_raster, fraction=0.046, pad=0.04)

plt.tight_layout()
plt.show()
No description has been provided for this image

pre-process validation raster¶

Because our model was trained with a scale of 4, we need to pre-process the rasters to fit the model's constraints.

In [8]:
# Use CPU for deterministic notebook inference/debugging.
# An empty visible-device list hides every GPU from TensorFlow.
# NOTE(review): the cuDNN/cuFFT log lines printed by the `t02.train` import suggest
# TensorFlow was already loaded earlier; confirm this call still takes effect at this point.
# NOTE(review): this import sits mid-notebook; consider moving it to the top import cell.
import tensorflow as tf
tf.config.set_visible_devices([], 'GPU')
In [9]:
# Preprocess rasters into model-ready normalized tensors.
# NOTE(review): this cell rebinds `lr_raw` / `hr_raw` to their normalized versions, so it is
# not idempotent: running it twice applies the depth scaling twice, and re-running the
# raw-input plotting cell afterwards shows normalized data under "raw" titles.
# Crop HR/DEM to be divisible by SCALE (avoids edge mismatches).
hr_h, hr_w = hr_raw.shape
crop_h = hr_h - (hr_h % SCALE)
crop_w = hr_w - (hr_w % SCALE)
if (crop_h, crop_w) != (hr_h, hr_w):
    print(f"Cropping HR/DEM from {(hr_h, hr_w)} to {(crop_h, crop_w)} for SCALE alignment")
# Only trailing rows/columns are dropped, so the top-left origin is unchanged.
hr_raw = hr_raw[:crop_h, :crop_w]
dem_raw = dem_raw[:crop_h, :crop_w]

# Depth normalization (clip -> log1p scale).
# This must match training exactly so model inputs are in-distribution.
# Saturation = fraction of cells at or above MAX_DEPTH, measured before scaling.
sat_lr = float(np.mean(lr_raw >= MAX_DEPTH))
sat_hr = float(np.mean(hr_raw >= MAX_DEPTH))
print(f"Depth saturation @ MAX_DEPTH: LR={sat_lr:.2%}, HR={sat_hr:.2%}")

# From here on `lr_raw` / `hr_raw` hold normalized values despite their names.
lr_raw = scale_depth_log1p_np(lr_raw, max_depth=MAX_DEPTH)
hr_raw = scale_depth_log1p_np(hr_raw, max_depth=MAX_DEPTH)

# DEM normalization: clip negatives, cap upper percentile, then min-max scale.
# Optionally reuse training DEM stats from train_config.json for consistency.
dem_norm, dem_stats = normalize_dem(dem_raw, pct_clip=DEM_PCT_CLIP, ref_stats=DEM_REF_STATS)
assert dem_norm is not None and dem_stats is not None
if DEM_REF_STATS is None:
    print("DEM stats computed from inference raster:", dem_stats)
else:
    print("DEM stats reused from train_config:", dem_stats)

# Resample LR depth to match HR/SCALE spatial size (bilinear).
# Resulting LR grid aligns with expected model input tile shape.
# NOTE(review): the LR raster is resampled from its full, uncropped extent while HR/DEM may
# have lost up to SCALE-1 trailing rows/columns above, so the two grids can differ by a
# small stretch near the far edges; confirm this is acceptable.
target_lr_h = crop_h // SCALE
target_lr_w = crop_w // SCALE
lr_norm = resize_bilinear_2d_np(lr_raw, (target_lr_h, target_lr_w), antialias=True)
lr_norm = np.clip(lr_norm, 0.0, 1.0)

assert dem_norm.shape == hr_raw.shape
assert lr_norm.shape == (target_lr_h, target_lr_w)

# Pad to tile multiples for windowed inference.
# Padding avoids partial edge tiles; we crop back after prediction.
pad_h = (int(math.ceil(crop_h / HR_TILE)) * HR_TILE) - crop_h
pad_w = (int(math.ceil(crop_w / HR_TILE)) * HR_TILE) - crop_w

# All three rasters are zero-padded on the bottom/right only.
hr_pad = np.pad(hr_raw, ((0, pad_h), (0, pad_w)), mode="constant", constant_values=0.0)
dem_pad = np.pad(dem_norm, ((0, pad_h), (0, pad_w)), mode="constant", constant_values=0.0)
# crop_* and HR_TILE are both multiples of SCALE, so pad_* divides evenly here.
lr_pad = np.pad(
    lr_norm,
    ((0, pad_h // SCALE), (0, pad_w // SCALE)),
    mode="constant",
    constant_values=0.0,
)

print("Padded shapes HR/LR:", hr_pad.shape, lr_pad.shape)
Cropping HR/DEM from (2030, 2090) to (2028, 2088) for SCALE alignment
Depth saturation @ MAX_DEPTH: LR=0.00%, HR=0.00%
DEM stats reused from train_config: {'p_clip': 1036.0579833984375, 'dem_min': 176.46800231933594, 'dem_max': 1036.0579833984375}
Padded shapes HR/LR: (2048, 2304) (512, 576)

plot normalized inputs¶

In [10]:
# Post-normalization sanity check: confirm the model-ready arrays have the expected ranges/patterns.
# Layout: one row per raster (LR depth, HR depth, DEM); histogram on the left, map on the right.
# The dry/wet threshold is pushed through the same depth scaling so it can be drawn on normalized axes.
dry_thresh_m_arr = np.array([DRY_DEPTH_THRESH_M], dtype=np.float32)
dry_thresh_norm = float(scale_depth_log1p_np(dry_thresh_m_arr, max_depth=MAX_DEPTH)[0])

# `hr_raw` holds the normalized HR depth at this point (rebound during preprocessing).
plot_specs = [
    ("LR depth (normalized)", lr_norm, "viridis", True, dry_thresh_norm),
    ("HR depth (normalized)", hr_raw, "viridis", True, dry_thresh_norm),
    ("DEM (normalized)", dem_norm, "terrain", False, None),
]

fig, axes = plt.subplots(nrows=3, ncols=2, figsize=(10, 12))

for (ax_hist, ax_raster), spec in zip(axes, plot_specs):
    title, arr, cmap, use_dry_mask, dry_thresh = spec
    arr = np.asarray(arr)
    vals = arr[np.isfinite(arr)]  # histogram and stats ignore non-finite cells

    # Left panel: value distribution, with the dry/wet threshold marked for depth rasters.
    ax_hist.hist(vals, bins=60, color='steelblue', alpha=0.9)
    if use_dry_mask:
        ax_hist.axvline(dry_thresh, color='red', linestyle='--', linewidth=1.5)
    ax_hist.set_title(f"{title} histogram")
    ax_hist.set_xlabel('Value')
    ax_hist.set_ylabel('Count')
    ax_hist.grid(color='lightgrey', linestyle='-', linewidth=0.7)

    # Summary statistics in the top-right corner of the histogram.
    stats_text = f"shape: {arr.shape}\nmin: {vals.min():.3f}\nmax: {vals.max():.3f}\nmean: {vals.mean():.3f}\nstd: {vals.std():.3f}"
    ax_hist.text(
        0.98,
        0.95,
        stats_text,
        transform=ax_hist.transAxes,
        fontsize=9,
        verticalalignment='top',
        horizontalalignment='right',
    )

    # Right panel: the raster itself; depth cells below the dry threshold are hidden.
    if use_dry_mask:
        raster_arr = np.ma.masked_where(arr < dry_thresh, arr)
    else:
        raster_arr = arr
    im = ax_raster.imshow(raster_arr, cmap=cmap)
    ax_raster.set_title(f"{title} raster")
    ax_raster.set_axis_off()
    fig.colorbar(im, ax=ax_raster, fraction=0.046, pad=0.04)

plt.tight_layout()
plt.show()
No description has been provided for this image

Per-chip inference first, then mosaicing¶

Run simple non-overlap chip inference for diagnostics, render test-chip style plots, then build the final mosaic while reusing cached chip predictions whenever possible.

In [11]:
# Load the exported inference model directly.
# This artifact already contains best validation-loss weights (see train.py export logic).
# compile=False: only forward passes are run here, so no optimizer/loss state is requested.
model = tf.keras.models.load_model(model_fp, compile=False)
model.trainable = False  # mark the model as not trainable for this inference-only notebook
print(f"Loaded model (best-weight export): {model_fp}")
INFO:tensorflow:Mixed precision compatibility check (mixed_float16): OK
Your GPU will likely run quickly with dtype policy mixed_float16 as it has compute capability of at least 7.0. Your GPU: NVIDIA RTX 4000 Ada Generation Laptop GPU, compute capability 8.9
Loaded model (best-weight export): t02/train/full/02/model_infer.keras
In [12]:
# Run per-chip inference first (no mosaicing), then build the full-scene mosaic.
# Padded HR dimensions; both are multiples of HR_TILE after the padding step.
hr_pad_h, hr_pad_w = hr_pad.shape


# Generate tile start positions and force coverage of the trailing edge.
def build_tile_starts(total_size, tile_size, stride):
    """Return ascending tile start offsets along one axis that cover ``[0, total_size)``.

    Starts are spaced ``stride`` apart from 0. If the last regular tile would
    stop short of the trailing edge, one extra start at
    ``total_size - tile_size`` is appended so the final tile ends exactly at
    ``total_size`` (that tile may overlap its neighbour by more than the
    regular overlap).

    When ``total_size < tile_size`` a single start at 0 is returned rather
    than a negative offset; the caller must pad the axis to at least
    ``tile_size``.

    Raises:
        ValueError: if ``tile_size`` or ``stride`` is not positive.
    """
    if tile_size <= 0:
        raise ValueError(f"tile_size must be > 0; got {tile_size}")
    if stride <= 0:
        raise ValueError(f"stride must be > 0; got {stride}")

    # Clamp at 0 so a too-small axis never yields a negative start, which
    # would wrap around when used as a slice index.
    last_start = max(total_size - tile_size, 0)
    starts = list(range(0, last_start + 1, stride))
    if starts[-1] != last_start:
        starts.append(last_start)
    return starts


# Tile-prediction cache keyed by HR-grid tile origin (y0, x0).
# Filled by `_predict_tile`; values are the arrays returned by `extract_sr_dem_prediction_np`.
# NOTE(review): the key carries only the tile origin, so entries go stale if the model or the
# padded inputs change; re-run this cell to reset the cache in that case.
tile_pred_cache = {}


# Predict one chip and cache it so feathered mosaicing can reuse non-overlap results.
def _predict_tile(y0, x0):
    """Return the model prediction for the HR tile whose top-left corner is ``(y0, x0)``.

    ``y0`` / ``x0`` are offsets on the padded HR grid; the matching LR window
    starts at ``(y0 // SCALE, x0 // SCALE)``. Results are memoised in
    ``tile_pred_cache`` under ``(int(y0), int(x0))``, so repeated requests for
    the same origin do not re-run the model.
    """
    key = (int(y0), int(x0))
    try:
        return tile_pred_cache[key]
    except KeyError:
        pass  # not predicted yet; fall through and run the model

    # Cut the LR depth window and the HR DEM window for this tile.
    lr_y0, lr_x0 = y0 // SCALE, x0 // SCALE
    lr_window = lr_pad[lr_y0 : lr_y0 + LR_TILE, lr_x0 : lr_x0 + LR_TILE]
    dem_window = dem_pad[y0 : y0 + HR_TILE, x0 : x0 + HR_TILE]

    # The shared helper validates the expected shapes and returns the two model inputs.
    lr_batch, dem_batch = prepare_sr_dem_model_inputs_np(
        depth_lr_norm=lr_window,
        dem_hr_norm=dem_window,
        expected_lr_shape=(LR_TILE, LR_TILE, 1),
        expected_hr_shape=(HR_TILE, HR_TILE, 1),
    )

    raw_pred = model((lr_batch, dem_batch), training=False)
    tile_pred = extract_sr_dem_prediction_np(raw_pred, expected_hr_shape=(HR_TILE, HR_TILE, 1))
    tile_pred_cache[key] = tile_pred
    return tile_pred
In [13]:
# Step 1: simple non-overlap per-chip inference for diagnostics.
# Tiles sit on a regular HR_TILE grid over the padded scene; every prediction also lands in
# `tile_pred_cache` through `_predict_tile`.
nonoverlap_y_starts = list(range(0, hr_pad_h, HR_TILE))
nonoverlap_x_starts = list(range(0, hr_pad_w, HR_TILE))
sr_pad_nonoverlap = np.zeros_like(hr_pad, dtype=np.float32)

print(
    f"Running simple non-overlap per-chip inference on "
    f"{len(nonoverlap_y_starts) * len(nonoverlap_x_starts)} chips..."
)
for y0 in nonoverlap_y_starts:
    for x0 in nonoverlap_x_starts:
        pred_np = _predict_tile(y0, x0)
        sr_pad_nonoverlap[y0 : y0 + HR_TILE, x0 : x0 + HR_TILE] = pred_np


# Build withheld-test-like chip stacks from fully valid (non-padded) chip locations.
# Only tiles lying entirely inside the cropped (unpadded) scene count as valid.
valid_tiles_h = crop_h // HR_TILE
valid_tiles_w = crop_w // HR_TILE
if valid_tiles_h == 0 or valid_tiles_w == 0:
    raise ValueError(
        f"No fully valid chips for diagnostics (crop={(crop_h, crop_w)}, HR_TILE={HR_TILE})."
    )

n_valid = valid_tiles_h * valid_tiles_w
lowres_chips = np.zeros((n_valid, LR_TILE, LR_TILE, 1), dtype=np.float32)
highres_chips = np.zeros((n_valid, HR_TILE, HR_TILE, 1), dtype=np.float32)
preds_chips = np.zeros((n_valid, HR_TILE, HR_TILE, 1), dtype=np.float32)

# Row-major list of HR-grid origins for the valid chips.
chip_coords = [
    (ty * HR_TILE, tx * HR_TILE)
    for ty in range(valid_tiles_h)
    for tx in range(valid_tiles_w)
]

for chip_idx, (y0, x0) in enumerate(chip_coords):
    lr_y0 = y0 // SCALE
    lr_x0 = x0 // SCALE

    lowres_chips[chip_idx, ..., 0] = lr_pad[lr_y0 : lr_y0 + LR_TILE, lr_x0 : lr_x0 + LR_TILE]
    highres_chips[chip_idx, ..., 0] = hr_pad[y0 : y0 + HR_TILE, x0 : x0 + HR_TILE]
    # Every valid origin was predicted in the loop above, so the cache lookup cannot miss.
    preds_chips[chip_idx, ..., 0] = tile_pred_cache[(y0, x0)]

chip_idx = len(chip_coords)  # total number of chips written
chip_coords = np.asarray(chip_coords, dtype=np.int32)
print(f"Prepared {chip_idx} valid chips ({valid_tiles_h} x {valid_tiles_w}) for diagnostics.")
Running simple non-overlap per-chip inference on 72 chips...
Prepared 56 valid chips (7 x 8) for diagnostics.
In [14]:
# Compute per-chip summary + per-sample metrics (model vs bilinear baseline).
# Chips are passed in normalized space together with MAX_DEPTH and the dry/wet threshold.
chip_summary, chip_per_sample = results.evaluate_chip_arrays_vs_bilinear(
    lowres_chips=lowres_chips,
    highres_chips=highres_chips,
    preds_chips=preds_chips,
    max_depth=MAX_DEPTH,
    split_name='inference_chips',
    dry_depth_thresh_m=DRY_DEPTH_THRESH_M,
)

# Stored under a 'test' key; presumably this mirrors the layout used by the
# training-results diagnostics (verify against training_results.ipynb).
GLOBAL_METRICS = {'test': chip_summary}
PER_SAMPLE_METRICS = {'test': chip_per_sample}

print('Per-chip GLOBAL_METRICS[test]:')
print(json.dumps(GLOBAL_METRICS['test'], indent=2, sort_keys=True))


# Reproduce training_results.ipynb diagnostics on withheld-style chips.
test_samples = PER_SAMPLE_METRICS['test']

fig_scatter, _ = results.plot_metric_scatter_vs_mean_depth(test_samples)
plt.show()
plt.close(fig_scatter)  # release the figure once it has been rendered
Per-chip GLOBAL_METRICS[test]:
{
  "baseline": {
    "CSI": 0.21588961780071259,
    "MAE": 0.017262272536754608,
    "PSNR": 27.65399169921875,
    "RMSE": 0.0518011711537838,
    "RMSE_wet": 0.42853009700775146,
    "SSIM": 0.8069341778755188
  },
  "best_epoch": {
    "CSI": 0.21867890655994415,
    "MAE": 0.01664661057293415,
    "PSNR": 27.6664981842041,
    "RMSE": 0.05031589791178703,
    "RMSE_wet": 0.45040664076805115,
    "SSIM": 0.8093225359916687
  }
}
No description has been provided for this image
In [15]:
# Scatter of per-chip statistics, drawn from the per-sample metrics computed above.
fig_chip_scatter, _ = results.plot_chip_stat_scatter(test_samples)
plt.show()
plt.close(fig_chip_scatter)  # release the figure once it has been rendered
No description has been provided for this image
In [16]:
# Show the n_show best and worst chips; chip_ids carries each chip's HR-grid origin (y0, x0).
# The return value is assigned to `_` so it is not echoed as cell output.
_ = results.plot_best_worst_chip_examples(
    lowres_chips=lowres_chips,
    highres_chips=highres_chips,
    preds_chips=preds_chips,
    max_depth=MAX_DEPTH,
    n_show=3,
    dry_depth_thresh_m=DRY_DEPTH_THRESH_M,
    cmap='cividis',
    chip_ids=chip_coords,
)
Scanned test chips: 56
Retained candidate chips in memory: 6

Worst chips (highest SR MAE):
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
Best chips (lowest SR MAE):
/usr/local/lib/python3.11/dist-packages/numpy/lib/histograms.py:885: RuntimeWarning: divide by zero encountered in divide
  return n/db/n.sum(), bin_edges
/usr/local/lib/python3.11/dist-packages/numpy/lib/histograms.py:885: RuntimeWarning: invalid value encountered in divide
  return n/db/n.sum(), bin_edges
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
In [17]:
# Step 2: build final mosaic, reusing cached chip predictions whenever possible.
if WINDOW_METHOD == 'hard':
    # Hard windows: the non-overlap mosaic from Step 1 already is the result.
    sr_pad = sr_pad_nonoverlap.copy()
    print(
        f"Mosaicing with WINDOW_METHOD='hard' using cached chips only "
        f"({len(tile_pred_cache)} predictions cached)."
    )

elif WINDOW_METHOD == 'feather':
    # Overlap and stride expressed on the HR grid.
    overlap_hr = FEATHER_OVERLAP_LR * SCALE
    stride_hr = HR_TILE - overlap_hr

    if stride_hr <= 0:
        raise AssertionError(
            f"Feather stride must be > 0; got stride_hr={stride_hr} from FEATHER_OVERLAP_LR={FEATHER_OVERLAP_LR}"
        )

    # The last start on each axis is forced to hr_pad - HR_TILE, so the final tile may
    # overlap its neighbour by more than overlap_hr.
    y_starts = build_tile_starts(hr_pad_h, HR_TILE, stride_hr)
    x_starts = build_tile_starts(hr_pad_w, HR_TILE, stride_hr)

    # 1-D weight profile: flat 1.0 in the middle, linear ramps over the first and last
    # overlap_hr pixels. The ramp omits the exact 0 and 1 endpoints ([1:-1]).
    feather_1d = np.ones(HR_TILE, dtype=np.float32)
    if overlap_hr > 0:
        ramp = np.linspace(0.0, 1.0, overlap_hr + 2, dtype=np.float32)[1:-1]
        feather_1d[:overlap_hr] = ramp
        feather_1d[-overlap_hr:] = ramp[::-1]
    # Floor the weights so no pixel of a tile contributes exactly zero.
    feather_1d = np.clip(feather_1d, 1e-3, 1.0)

    # Weighted sum of predictions and the sum of weights, both on the padded HR grid.
    accum = np.zeros_like(hr_pad, dtype=np.float32)
    weight_sum = np.zeros_like(hr_pad, dtype=np.float32)

    print(
        f"Mosaicing with WINDOW_METHOD='feather' using {len(y_starts) * len(x_starts)} windows "
        f"(overlap={overlap_hr} px, stride={stride_hr} px)..."
    )

    for yi, y0 in enumerate(y_starts):
        for xi, x0 in enumerate(x_starts):
            # Origins that coincide with the Step 1 grid are served from the cache.
            pred_np = _predict_tile(y0, x0)

            wy = feather_1d.copy()
            wx = feather_1d.copy()

            # Tiles on the scene border have no neighbour on that side, so the ramp on
            # the outward-facing edge is reset to full weight.
            if yi == 0:
                wy[:overlap_hr] = 1.0
            if yi == len(y_starts) - 1:
                wy[-overlap_hr:] = 1.0
            if xi == 0:
                wx[:overlap_hr] = 1.0
            if xi == len(x_starts) - 1:
                wx[-overlap_hr:] = 1.0

            # Separable 2-D weight = outer product of the row and column profiles.
            weight = np.outer(wy, wx).astype(np.float32, copy=False)
            accum[y0 : y0 + HR_TILE, x0 : x0 + HR_TILE] += pred_np * weight
            weight_sum[y0 : y0 + HR_TILE, x0 : x0 + HR_TILE] += weight

    # Normalize by the accumulated weight; pixels with no coverage stay 0.
    sr_pad = np.divide(
        accum,
        np.maximum(weight_sum, 1e-6),
        out=np.zeros_like(accum),
        where=weight_sum > 0,
    )
    print(f"Cached predictions after feather mosaicing: {len(tile_pred_cache)}")

else:
    raise ValueError(f"Unsupported WINDOW_METHOD: {WINDOW_METHOD!r}")


# Remove padding and keep normalized prediction range bounded.
sr = np.clip(sr_pad[:crop_h, :crop_w], 0.0, 1.0)
# `hr_raw` already holds the normalized, cropped HR depth; this slice is the validation target.
hr_valid = hr_raw[:crop_h, :crop_w]
print('SR shape:', sr.shape)
Mosaicing with WINDOW_METHOD='feather' using 132 windows (overlap=64 px, stride=192 px)...
Cached predictions after feather mosaicing: 188
SR shape: (2028, 2088)
In [18]:
# Export prediction as GeoTIFF in depth units (meters).
ofp = Path('t02/inference/sr_result.tif')
ofp.parent.mkdir(parents=True, exist_ok=True)  # create the output folder on first run

# Convert model output from normalized log-space back to depth (meters).
sr_depth_m = invert_depth_log1p_np(sr, max_depth=MAX_DEPTH)

print(f"Saving SR result to {ofp}...")
# Start from the HR raster's profile so the georeferencing is inherited. Only trailing
# rows/columns were cropped, so just height/width need updating.
out_profile = hr_profile.copy()
out_profile.update(
    dtype=rasterio.float32,
    count=1,
    compress='deflate',
    height=sr_depth_m.shape[0],
    width=sr_depth_m.shape[1],
)
with rasterio.open(ofp, 'w', **out_profile) as dst:
    dst.write(sr_depth_m.astype(np.float32), 1)  # write as band 1

print(
    f"Saved depth raster in meters. Range: min={float(np.nanmin(sr_depth_m)):.4f}, max={float(np.nanmax(sr_depth_m)):.4f}"
)
Saving SR result to t02/inference/sr_result.tif...
Saved depth raster in meters. Range: min=0.0000, max=4.4980

Bilinear baseline + mosaic-level metrics¶

In [19]:
# Build a bilinear baseline from LR input and compare against HR target.
# NOTE(review): the baseline is upsampled from the zero-padded LR array and cropped afterwards,
# so zeros from the padding may bleed into the last rows/columns of the valid area; confirm
# against the resize helper's edge behaviour.
baseline_pad = resize_bilinear_2d_np(lr_pad, (hr_pad_h, hr_pad_w), antialias=True)
baseline = np.clip(baseline_pad[:crop_h, :crop_w], 0.0, 1.0)


# Compute mosaic-level metrics in normalized space with shared helpers.
# [None, ..., None] adds a leading and a trailing axis of length 1 to each 2-D scene.
hr_full = tf.convert_to_tensor(hr_valid[None, ..., None], dtype=tf.float32)
sr_full = tf.convert_to_tensor(sr[None, ..., None], dtype=tf.float32)
bl_full = tf.convert_to_tensor(baseline[None, ..., None], dtype=tf.float32)

sr_metric_tensors = results.compute_per_sample_metrics(hr_full, sr_full)
bl_metric_tensors = results.compute_per_sample_metrics(hr_full, bl_full)

# Each metric is wrapped in a one-element list before being passed to the reducer.
metrics_sr = results.reduce_metric_buffers({k: [v] for k, v in sr_metric_tensors.items()})
metrics_bilinear = results.reduce_metric_buffers({k: [v] for k, v in bl_metric_tensors.items()})

# Side-by-side table, with rows ordered by results.METRIC_KEYS.
df = pd.DataFrame({
    'ResUNet': metrics_sr,
    'Bilinear': metrics_bilinear,
})
df = df.loc[list(results.METRIC_KEYS), ['ResUNet', 'Bilinear']]
df.round(4)
Out[19]:
ResUNet Bilinear
MAE 0.0151 0.0156
PSNR 25.2495 24.4981
SSIM 0.8264 0.8244
RMSE 0.0546 0.0596
RMSE_wet 0.4095 0.3185
CSI 0.2944 0.2879

PLOT Inference (mosaic-level)¶

In [20]:
# Full-scene diagnostics: one comparison figure plus summary metrics for the whole mosaic.
# Inputs are the normalized LR / HR / SR scenes; the shared plotting helper receives MAX_DEPTH
# and the dry/wet threshold along with them.

DRY_DEPTH_THRESH_PLOT_M = DRY_DEPTH_THRESH_M

# Add a trailing axis of length 1 to each 2-D scene and convert to float32 tensors.
full_lr, full_hr, full_sr = (
    tf.convert_to_tensor(scene[..., None], dtype=tf.float32)
    for scene in (lr_norm, hr_valid, sr)
)

print("Full-scene inference diagnostics")
fig, final_metrics = results.plot_chip_comparison(
    highres=full_hr,
    lowres=full_lr,
    preds=full_sr,
    max_depth=MAX_DEPTH,
    dry_depth_thresh_m=DRY_DEPTH_THRESH_PLOT_M,
    cmap="cividis",
)
plt.show()
plt.close(fig)

# Headline metrics, printed in a fixed order.
tile_label = "full-scene"
for template, metric_key in (
    ("PSNR between LR and HR image {}: {:.4f}", "lr_psnr"),
    ("SSIM between LR and HR image {}: {:.4f}", "lr_ssim"),
    ("PSNR between HR and SR image {}: {:.4f}", "sr_psnr"),
    ("SSIM between HR and SR image {}: {:.4f}", "sr_ssim"),
    ("MAE between HR and SR image {}: {:.6f} m", "sr_mae_m"),
):
    print(template.format(tile_label, final_metrics[metric_key]))

# Last expression: display the full metrics dict as the cell output.
final_metrics
Full-scene inference diagnostics
No description has been provided for this image
PSNR between LR and HR image full-scene: 28.4491
SSIM between LR and HR image full-scene: 0.8828
PSNR between HR and SR image full-scene: 29.0922
SSIM between HR and SR image full-scene: 0.8889
MAE between HR and SR image full-scene: 0.037980 m
Out[20]:
{'lr_psnr': 28.449148178100586,
 'lr_ssim': 0.8827940225601196,
 'lr_mae_m': 0.04028255119919777,
 'lr_rmse_m': 0.18902209401130676,
 'lr_rmse_wet_m': 0.5304687023162842,
 'lr_bias_m': 0.010088089853525162,
 'lr_wet_pixel_count': 336221,
 'lr_dry_pixel_count': 3898243,
 'sr_psnr': 29.0921688079834,
 'sr_ssim': 0.8889098763465881,
 'sr_mae_m': 0.03798045590519905,
 'sr_rmse_m': 0.17553409934043884,
 'sr_rmse_wet_m': 0.5603041052818298,
 'sr_bias_m': -0.0012882208684459329,
 'sr_wet_pixel_count': 336221,
 'sr_dry_pixel_count': 3898243,
 'hr_wet_pixel_count': 336221,
 'hr_dry_pixel_count': 3898243,
 'hr_mean_depth_m': 0.03274461627006531,
 'hr_max_depth_m': 5.0,
 'hr_min_depth_m': 0.0,
 'bl_psnr': 28.636943817138672,
 'bl_ssim': 0.8845934271812439,
 'bl_mae_m': 0.039947520941495895,
 'bl_rmse_m': 0.1849791556596756,
 'bl_rmse_wet_m': 0.5248517990112305,
 'bl_bias_m': 0.009580918587744236,
 'bl_wet_pixel_count': 336221,
 'bl_dry_pixel_count': 3898243}